/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "c/ddk/graph/operator.h"
#include "c/ddk/graph/context.h"

#include <algorithm>
#include <cstring>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "resource_manager.h"
#include "framework/infra/log/log.h"
#include "graph/operator.h"
#include "graph/types.h"
#include "graph/op/all_ops.h"
#include "framework/graph/operator_factory.h"
#include "infra/base/assertion.h"

using namespace hiai;

// Op types whose "x" input is dynamic: HIAI_IR_CreateOp registers
// opConfig->inputNums instances of it for these types.
// Read-only and file-local, like the other lookup tables in this file.
static const std::vector<const char*> DynamicInputLst = {"Eltwise", "ConcatD", "Pack"};
// Op types whose "y" output is dynamic: HIAI_IR_CreateOp registers
// opConfig->outputNums instances of it for these types.
static const std::vector<const char*> DynamicOutputLst = {"SplitD", "SplitV", "Unpack"};

// Translation table from the C API data type enum to the graph engine's
// ge::DataType. A data type with no entry here is rejected by
// HIAI_IR_CreatePlaceHodler.
static const std::map<HIAI_DataType, ge::DataType> DATA_TYPE_MAP = {
    // floating point
    {HIAI_DATATYPE_FLOAT16, ge::DataType::DT_FLOAT16},
    {HIAI_DATATYPE_FLOAT32, ge::DataType::DT_FLOAT},
    {HIAI_DATATYPE_DOUBLE, ge::DataType::DT_DOUBLE},
    // signed integers
    {HIAI_DATATYPE_INT8, ge::DataType::DT_INT8},
    {HIAI_DATATYPE_INT16, ge::DataType::DT_INT16},
    {HIAI_DATATYPE_INT32, ge::DataType::DT_INT32},
    {HIAI_DATATYPE_INT64, ge::DataType::DT_INT64},
    // unsigned integers
    {HIAI_DATATYPE_UINT8, ge::DataType::DT_UINT8},
    {HIAI_DATATYPE_UINT32, ge::DataType::DT_UINT32},
    // boolean
    {HIAI_DATATYPE_BOOL, ge::DataType::DT_BOOL},
};

/**
 * Creates a hiai::op::Data placeholder operator named opName.
 *
 * Its "x" input tensor description is built from format, dataType and the
 * inputShapeNums dimensions read from inputShape. The operator is handed to
 * resMgr via StoreSrcPtr. The returned raw handle is only valid while resMgr
 * retains that pointer; presumably the resource manager owns its lifetime
 * (confirm in ResourceManager).
 *
 * @return the operator handle, or nullptr when resMgr/opName/inputShape is
 *         null, inputShapeNums is 0, or dataType has no DATA_TYPE_MAP entry.
 */
OpHandle HIAI_IR_CreatePlaceHodler(ResMgrHandle resMgr, const char* opName,
    HIAI_Format format, HIAI_DataType dataType, int64_t inputShape[], uint32_t inputShapeNums)
{
    const bool argsValid =
        resMgr != nullptr && opName != nullptr && inputShape != nullptr && inputShapeNums != 0;
    if (!argsValid) {
        FMK_LOGE("resMgr, opName, inputShape or inputShapeNums is invalid");
        return nullptr;
    }

    const auto typeIt = DATA_TYPE_MAP.find(dataType);
    if (typeIt == DATA_TYPE_MAP.end()) {
        FMK_LOGE("dataType %d is not support now", static_cast<int>(dataType));
        return nullptr;
    }

    std::vector<int64_t> dims(inputShape, inputShape + inputShapeNums);
    ge::TensorDesc tensorDesc(ge::Shape(dims), static_cast<ge::Format>(format), typeIt->second);

    auto dataOp = std::make_shared<hiai::op::Data>(opName);
    dataOp->update_input_desc_x(tensorDesc);

    // Data -> Operator -> Base are plain upcasts, so implicit conversions suffice.
    BasePtr handleOwner = std::shared_ptr<ge::Operator>(dataOp);
    reinterpret_cast<ResourceManager*>(resMgr)->StoreSrcPtr(handleOwner);
    return handleOwner.get();
}

/**
 * Instantiates an operator of type opType named opName through the
 * process-wide hiai::OperatorFactory.
 *
 * @return nullptr if the factory instance is unavailable. Otherwise returns
 *         whatever the factory produced, which callers must still null-check
 *         (HIAI_IR_CreateOp does); presumably null for unregistered types.
 */
std::shared_ptr<ge::Operator> CreateSubOp(const char* opType, const char* opName)
{
    auto* const opFactory = hiai::OperatorFactory::Instance();
    HIAI_EXPECT_NOT_NULL_R(opFactory, nullptr);
    return opFactory->CreateOperator(opType, opName);
}

/**
 * Returns true when opType compares equal (strcmp) to one of the names in typeList.
 */
static bool IsOpTypeInList(const std::vector<const char*>& typeList, const char* opType)
{
    return std::any_of(typeList.begin(), typeList.end(),
        [opType](const char* listed) { return strcmp(opType, listed) == 0; });
}

/**
 * Creates an operator of type opConfig->opType named opConfig->opName, wires
 * its inputs and attributes from opConfig, and stores it in resMgr.
 *
 * For types in DynamicInputLst, the dynamic "x" input is registered with
 * opConfig->inputNums entries, which must be non-zero. For types in
 * DynamicOutputLst, the dynamic "y" output is registered with
 * opConfig->outputNums entries, which must also be non-zero.
 *
 * Each opConfig->input[i].op and opConfig->params[i].attrValue is resolved
 * through resMgr->GetSrcPtr. It must therefore be a handle previously stored
 * in the same resource manager.
 *
 * @return the new operator handle, or nullptr on an invalid argument, a failed
 *         operator creation, or an input/attribute handle that cannot be
 *         resolved to the expected type.
 */
OpHandle HIAI_IR_CreateOp(ResMgrHandle resMgr, HIAI_IR_OpConfig* opConfig)
{
    if (resMgr == nullptr || opConfig == nullptr || opConfig->opType == nullptr || opConfig->opName == nullptr) {
        FMK_LOGE("resMgr or opConfig is invalid");
        return nullptr;
    }
    // The loops below index these arrays once per counted element, so a null
    // array paired with a non-zero count would be dereferenced.
    if ((opConfig->inputNums != 0 && opConfig->input == nullptr) ||
        (opConfig->paramsNums != 0 && opConfig->params == nullptr)) {
        FMK_LOGE("opConfig input or params is null while its count is non-zero");
        return nullptr;
    }

    std::shared_ptr<ge::Operator> op = CreateSubOp(opConfig->opType, opConfig->opName);
    HIAI_EXPECT_NOT_NULL_R(op, nullptr);
    if (IsOpTypeInList(DynamicInputLst, opConfig->opType)) {
        // Same rule as for dynamic outputs: zero dynamic inputs cannot form a valid op.
        HIAI_EXPECT_TRUE_R(opConfig->inputNums != 0, nullptr);
        op->DynamicInputRegister("x", opConfig->inputNums);
    }
    if (IsOpTypeInList(DynamicOutputLst, opConfig->opType)) {
        HIAI_EXPECT_TRUE_R(opConfig->outputNums != 0, nullptr);
        op->DynamicOutputRegister("y", opConfig->outputNums);
    }

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    for (uint32_t i = 0; i < opConfig->inputNums; i++) {
        HIAI_EXPECT_NOT_NULL_R(opConfig->input[i].op, nullptr);
        BasePtr inputOpBasePtr = resMgrPtr->GetSrcPtr(opConfig->input[i].op);
        std::shared_ptr<ge::Operator> inputOpPtr = std::dynamic_pointer_cast<ge::Operator>(inputOpBasePtr);
        HIAI_EXPECT_NOT_NULL_R(inputOpPtr, nullptr);
        op->SetInput(i, *inputOpPtr, opConfig->input[i].outIndex);
    }
    for (uint32_t i = 0; i < opConfig->paramsNums; i++) {
        HIAI_EXPECT_NOT_NULL_R(opConfig->params[i].name, nullptr);
        HIAI_EXPECT_NOT_NULL_R(opConfig->params[i].attrValue, nullptr);
        BasePtr attrBasePtr = resMgrPtr->GetSrcPtr(opConfig->params[i].attrValue);
        std::shared_ptr<ge::AttrValue> attrPtr = std::dynamic_pointer_cast<ge::AttrValue>(attrBasePtr);
        HIAI_EXPECT_NOT_NULL_R(attrPtr, nullptr);
        // Hand over a copy. The resource manager still owns *attrPtr, and moving
        // from it would leave a hollow value behind if the same attribute handle
        // is reused for another operator.
        ge::AttrValue attrValue(*attrPtr);
        op->SetAttr(opConfig->params[i].name, std::move(attrValue));
    }

    resMgrPtr->StoreSrcPtr(std::shared_ptr<ge::Base>(op));

    return op.get();
}

// Operator types that can be created by name through the C API.
// Each line presumably registers a creator for hiai::op::<Type> with
// hiai::OperatorFactory under the type name given as the first argument. The
// macro is defined outside this file; confirm its exact semantics there.
// Consequently, opConfig->opType passed to HIAI_IR_CreateOp is expected to
// match one of these names.
// All types listed in DynamicInputLst / DynamicOutputLst (Eltwise, ConcatD,
// Pack, SplitD, SplitV, Unpack) are registered below.
REGISTER_OPERATOR_CREATOR(Acos, hiai::op::Acos);
REGISTER_OPERATOR_CREATOR(Activation, hiai::op::Activation);
REGISTER_OPERATOR_CREATOR(Add, hiai::op::Add);
REGISTER_OPERATOR_CREATOR(ArgMaxExt2, hiai::op::ArgMaxExt2);
REGISTER_OPERATOR_CREATOR(Asin, hiai::op::Asin);
REGISTER_OPERATOR_CREATOR(Atan, hiai::op::Atan);
REGISTER_OPERATOR_CREATOR(AvgPoolV2, hiai::op::AvgPoolV2);
REGISTER_OPERATOR_CREATOR(BatchMatMul, hiai::op::BatchMatMul);
REGISTER_OPERATOR_CREATOR(BatchToSpaceND, hiai::op::BatchToSpaceND);
REGISTER_OPERATOR_CREATOR(BiasAdd, hiai::op::BiasAdd);
REGISTER_OPERATOR_CREATOR(BNInference, hiai::op::BNInference);
REGISTER_OPERATOR_CREATOR(BroadcastTo, hiai::op::BroadcastTo);
REGISTER_OPERATOR_CREATOR(CastT, hiai::op::CastT);
REGISTER_OPERATOR_CREATOR(Ceil, hiai::op::Ceil);
REGISTER_OPERATOR_CREATOR(ClipByValue, hiai::op::ClipByValue);
REGISTER_OPERATOR_CREATOR(ConcatD, hiai::op::ConcatD);
REGISTER_OPERATOR_CREATOR(Const, hiai::op::Const);
REGISTER_OPERATOR_CREATOR(Convolution, hiai::op::Convolution);
REGISTER_OPERATOR_CREATOR(ConvolutionDepthwise, hiai::op::ConvolutionDepthwise);
REGISTER_OPERATOR_CREATOR(ConvTranspose, hiai::op::ConvTranspose);
REGISTER_OPERATOR_CREATOR(Cos, hiai::op::Cos);
REGISTER_OPERATOR_CREATOR(Crop, hiai::op::Crop);
REGISTER_OPERATOR_CREATOR(CropAndResize, hiai::op::CropAndResize);
REGISTER_OPERATOR_CREATOR(Data, hiai::op::Data);
REGISTER_OPERATOR_CREATOR(DepthToSpace, hiai::op::DepthToSpace);
REGISTER_OPERATOR_CREATOR(DequantizeV2, hiai::op::DequantizeV2);
REGISTER_OPERATOR_CREATOR(Eltwise, hiai::op::Eltwise);
REGISTER_OPERATOR_CREATOR(Equal, hiai::op::Equal);
REGISTER_OPERATOR_CREATOR(Erf, hiai::op::Erf);
REGISTER_OPERATOR_CREATOR(Exp, hiai::op::Exp);
REGISTER_OPERATOR_CREATOR(ExpandDims, hiai::op::ExpandDims);
REGISTER_OPERATOR_CREATOR(Expm1, hiai::op::Expm1);
REGISTER_OPERATOR_CREATOR(FakeQuantWithMinMaxVars, hiai::op::FakeQuantWithMinMaxVars);
REGISTER_OPERATOR_CREATOR(Fill, hiai::op::Fill);
REGISTER_OPERATOR_CREATOR(Flatten, hiai::op::Flatten);
REGISTER_OPERATOR_CREATOR(Floor, hiai::op::Floor);
REGISTER_OPERATOR_CREATOR(FloorDiv, hiai::op::FloorDiv);
REGISTER_OPERATOR_CREATOR(FloorMod, hiai::op::FloorMod);
REGISTER_OPERATOR_CREATOR(FullyConnection, hiai::op::FullyConnection);
REGISTER_OPERATOR_CREATOR(GatherNd, hiai::op::GatherNd);
REGISTER_OPERATOR_CREATOR(GatherV2D, hiai::op::GatherV2D);
REGISTER_OPERATOR_CREATOR(GemmD, hiai::op::GemmD);
REGISTER_OPERATOR_CREATOR(Greater, hiai::op::Greater);
REGISTER_OPERATOR_CREATOR(GreaterEqual, hiai::op::GreaterEqual);
REGISTER_OPERATOR_CREATOR(GridSampler2D, hiai::op::GridSampler2D);
REGISTER_OPERATOR_CREATOR(HardSwish, hiai::op::HardSwish);
REGISTER_OPERATOR_CREATOR(InstanceNorm, hiai::op::InstanceNorm);
REGISTER_OPERATOR_CREATOR(LayerNorm, hiai::op::LayerNorm);
REGISTER_OPERATOR_CREATOR(Less, hiai::op::Less);
REGISTER_OPERATOR_CREATOR(LessEqual, hiai::op::LessEqual);
REGISTER_OPERATOR_CREATOR(Log, hiai::op::Log);
REGISTER_OPERATOR_CREATOR(Log1p, hiai::op::Log1p);
REGISTER_OPERATOR_CREATOR(LogicalAnd, hiai::op::LogicalAnd);
REGISTER_OPERATOR_CREATOR(LogicalNot, hiai::op::LogicalNot);
REGISTER_OPERATOR_CREATOR(LogicalOr, hiai::op::LogicalOr);
REGISTER_OPERATOR_CREATOR(LogicalXor, hiai::op::LogicalXor);
REGISTER_OPERATOR_CREATOR(LogSoftmax, hiai::op::LogSoftmax);
REGISTER_OPERATOR_CREATOR(MatMul, hiai::op::MatMul);
REGISTER_OPERATOR_CREATOR(Maximum, hiai::op::Maximum);
REGISTER_OPERATOR_CREATOR(Minimum, hiai::op::Minimum);
REGISTER_OPERATOR_CREATOR(MirrorPad, hiai::op::MirrorPad);
REGISTER_OPERATOR_CREATOR(Mish, hiai::op::Mish);
REGISTER_OPERATOR_CREATOR(Mul, hiai::op::Mul);
REGISTER_OPERATOR_CREATOR(Neg, hiai::op::Neg);
REGISTER_OPERATOR_CREATOR(NonMaxSuppressionV6, hiai::op::NonMaxSuppressionV6);
REGISTER_OPERATOR_CREATOR(NotEqual, hiai::op::NotEqual);
REGISTER_OPERATOR_CREATOR(OneHot, hiai::op::OneHot);
REGISTER_OPERATOR_CREATOR(Pack, hiai::op::Pack);
REGISTER_OPERATOR_CREATOR(Pad, hiai::op::Pad);
REGISTER_OPERATOR_CREATOR(Permute, hiai::op::Permute);
REGISTER_OPERATOR_CREATOR(PoolingD, hiai::op::PoolingD);
REGISTER_OPERATOR_CREATOR(Pow, hiai::op::Pow);
REGISTER_OPERATOR_CREATOR(Power, hiai::op::Power);
REGISTER_OPERATOR_CREATOR(PRelu, hiai::op::PRelu);
REGISTER_OPERATOR_CREATOR(QuantizeV2, hiai::op::QuantizeV2);
REGISTER_OPERATOR_CREATOR(Range, hiai::op::Range);
REGISTER_OPERATOR_CREATOR(Rank, hiai::op::Rank);
REGISTER_OPERATOR_CREATOR(RealDiv, hiai::op::RealDiv);
REGISTER_OPERATOR_CREATOR(Reciprocal, hiai::op::Reciprocal);
REGISTER_OPERATOR_CREATOR(ReduceLogSumExp, hiai::op::ReduceLogSumExp);
REGISTER_OPERATOR_CREATOR(ReduceMax, hiai::op::ReduceMax);
REGISTER_OPERATOR_CREATOR(ReduceMean, hiai::op::ReduceMean);
REGISTER_OPERATOR_CREATOR(ReduceMin, hiai::op::ReduceMin);
REGISTER_OPERATOR_CREATOR(ReduceProdD, hiai::op::ReduceProdD);
REGISTER_OPERATOR_CREATOR(ReduceSum, hiai::op::ReduceSum);
REGISTER_OPERATOR_CREATOR(Reshape, hiai::op::Reshape);
REGISTER_OPERATOR_CREATOR(Resize, hiai::op::Resize);
REGISTER_OPERATOR_CREATOR(ResizeBicubic, hiai::op::ResizeBicubic);
REGISTER_OPERATOR_CREATOR(ResizeBilinear, hiai::op::ResizeBilinear);
REGISTER_OPERATOR_CREATOR(ResizeNearestNeighborV2, hiai::op::ResizeNearestNeighborV2);
REGISTER_OPERATOR_CREATOR(Rint, hiai::op::Rint);
REGISTER_OPERATOR_CREATOR(ROIAlignV2, hiai::op::ROIAlignV2);
REGISTER_OPERATOR_CREATOR(Round, hiai::op::Round);
REGISTER_OPERATOR_CREATOR(Rsqrt, hiai::op::Rsqrt);
REGISTER_OPERATOR_CREATOR(Scale, hiai::op::Scale);
REGISTER_OPERATOR_CREATOR(ScatterNd, hiai::op::ScatterNd);
REGISTER_OPERATOR_CREATOR(ScatterNdUpdate, hiai::op::ScatterNdUpdate);
REGISTER_OPERATOR_CREATOR(Select, hiai::op::Select);
REGISTER_OPERATOR_CREATOR(Shape, hiai::op::Shape);
REGISTER_OPERATOR_CREATOR(Sign, hiai::op::Sign);
REGISTER_OPERATOR_CREATOR(Sin, hiai::op::Sin);
REGISTER_OPERATOR_CREATOR(Size, hiai::op::Size);
REGISTER_OPERATOR_CREATOR(Slice, hiai::op::Slice);
REGISTER_OPERATOR_CREATOR(Softmax, hiai::op::Softmax);
REGISTER_OPERATOR_CREATOR(SpaceToBatchND, hiai::op::SpaceToBatchND);
REGISTER_OPERATOR_CREATOR(SpaceToDepth, hiai::op::SpaceToDepth);
REGISTER_OPERATOR_CREATOR(SparseToDense, hiai::op::SparseToDense);
REGISTER_OPERATOR_CREATOR(SplitD, hiai::op::SplitD);
REGISTER_OPERATOR_CREATOR(SplitV, hiai::op::SplitV);
REGISTER_OPERATOR_CREATOR(Sqrt, hiai::op::Sqrt);
REGISTER_OPERATOR_CREATOR(Square, hiai::op::Square);
REGISTER_OPERATOR_CREATOR(SquaredDifference, hiai::op::SquaredDifference);
REGISTER_OPERATOR_CREATOR(Squeeze, hiai::op::Squeeze);
REGISTER_OPERATOR_CREATOR(StridedSliceV2, hiai::op::StridedSliceV2);
REGISTER_OPERATOR_CREATOR(Sub, hiai::op::Sub);
REGISTER_OPERATOR_CREATOR(Swish, hiai::op::Swish);
REGISTER_OPERATOR_CREATOR(Tan, hiai::op::Tan);
REGISTER_OPERATOR_CREATOR(Threshold, hiai::op::Threshold);
REGISTER_OPERATOR_CREATOR(Tile, hiai::op::Tile);
REGISTER_OPERATOR_CREATOR(TopK, hiai::op::TopK);
REGISTER_OPERATOR_CREATOR(TruncateDiv, hiai::op::TruncateDiv);
REGISTER_OPERATOR_CREATOR(Unpack, hiai::op::Unpack);
REGISTER_OPERATOR_CREATOR(Xlogy, hiai::op::Xlogy);

